[fix][train] Prompt-based mini-batching for step-wise training#1529

Merged
CharlieFRuan merged 5 commits into main from prompt-based-mini-batching-v3
Apr 17, 2026
Conversation

@CharlieFRuan CharlieFRuan commented Apr 17, 2026

Summary

Step-wise training decomposes multi-turn trajectories into one training sequence per LLM turn, producing a variable number of sequences per prompt. This broke the old fixed-size mini-batching in two ways:

  1. Crash: Total sequence count wasn't always divisible by the fixed mini-batch size.
  2. More optimizer steps than intended: more turns → more mini-batches → more optimizer steps per training batch.

This PR shifts mini-batching from sequence units to prompt units. Each mini-batch now contains sequences for exactly policy_mini_batch_size prompts, regardless of how many sequences those prompts generated. This ensures the number of optimizer steps is always train_batch_size / policy_mini_batch_size * update_epochs_per_batch.

The padding overhead should be minimal, since the number of padded sequences per mini-batch is capped at dp_size.
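As a quick sanity check on the step-count invariant, here is the arithmetic with illustrative numbers (not defaults from the repo):

```python
# Illustrative numbers only -- not repo defaults.
train_batch_size = 128          # prompts per training batch
policy_mini_batch_size = 32     # prompts per mini-batch
update_epochs_per_batch = 2

# With prompt-based mini-batching, the optimizer step count depends only
# on prompt counts, never on how many sequences (turns) each prompt produced.
optimizer_steps = train_batch_size // policy_mini_batch_size * update_epochs_per_batch
print(optimizer_steps)  # 8
```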

Key changes

  • compute_prompt_mini_batch_boundaries() (skyrl/train/dataset/preprocess.py): walks a flat uids list, detects prompt boundaries by consecutive-equal groups, and slices them into (start, end) boundary pairs for each mini-batch. Asserts uid contiguity (a uid cannot re-appear after a gap). Asserts len(unique_uids) == train_batch_size. For non-step-wise, asserts boundaries are uniform (backward compatible).
  • MeshDispatch.stage_chunks() (dispatch.py): accepts mini_batch_boundaries instead of computing fixed-size chunks. Each mini-batch is individually padded to dp_size using pad_training_input_batch().
  • _normalize_advantages() and _execute_training_step() (trainer.py): iterate over boundary pairs instead of fixed-size slicing.
  • apply_loss_reduction_to_advantages_minibatch() (ppo_utils.py): does not support token_mean_legacy for now, since num_micro_batches depends on how each mini-batch is padded.
  • WorkerDispatch.stage_data() (worker_dispatch.py): passes boundaries through to stage_chunks.
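The boundary computation described above can be sketched as follows. This is a hypothetical simplification (function name, signature, and checks are assumptions), not the implementation in skyrl/train/dataset/preprocess.py:

```python
from itertools import groupby

def prompt_mini_batch_boundaries(uids, mini_batch_size):
    """Slice a flat, prompt-contiguous uid list into (start, end) pairs,
    one pair per mini-batch of `mini_batch_size` prompts.
    Hypothetical sketch -- not the actual skyrl implementation.
    """
    # Group consecutive equal uids into per-prompt runs: [(uid, count), ...].
    runs = [(uid, sum(1 for _ in grp)) for uid, grp in groupby(uids)]
    # Contiguity check: a uid must not reappear after a gap.
    seen = [uid for uid, _ in runs]
    assert len(seen) == len(set(seen)), "uids must be contiguous"
    assert len(runs) % mini_batch_size == 0, "prompt count must divide evenly"

    boundaries, pos = [], 0
    for i in range(0, len(runs), mini_batch_size):
        # Each mini-batch spans all sequences of `mini_batch_size` prompts,
        # however many sequences those prompts generated.
        length = sum(n for _, n in runs[i : i + mini_batch_size])
        boundaries.append((pos, pos + length))
        pos += length
    return boundaries
```

For non-step-wise data, every run has the same length (n_samples_per_prompt), so the boundaries reduce to uniform fixed-size slices.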

Backward compatibility

For non-step-wise training, where each prompt has exactly n_samples_per_prompt sequences, boundaries remain uniform — identical to the original fixed-size slicing. An assertion in compute_prompt_mini_batch_boundaries verifies this.

Test plan

  • tests/train/test_prompt_mini_batch.py: unit tests for compute_prompt_mini_batch_boundaries (non-step-wise, step-wise, contiguity assertion, boundary uniformity parametrized), MeshDispatch.stage_chunks (padding, loss_mask zeros, variable sizes), and optimizer step count invariance.
  • SearchR1 step-wise 4-turn training run with the new code on wandb project skyrl-search-padding.

Search-r1 Curves

Report link: https://api.wandb.ai/links/sky-posttraining-uc-berkeley/c43eauat

1. Verifying that non-step-wise training is unaffected before vs. after this PR, and comparing step-wise vs. non-step-wise with this PR


2. Step-wise across PRs

Analysis:


Commands used

On 8xH100s.

Non-stepwise:

USE_CONVERSATION_MULTI_TURN=true bash examples/train/search/run_search.sh \
  generator.inference_engine.num_engines=8 \
  generator.inference_engine.tensor_parallel_size=1

Step-wise:

USE_CONVERSATION_MULTI_TURN=true STEP_WISE=true bash examples/train/search/run_search.sh \
  generator.inference_engine.num_engines=8 \
  generator.inference_engine.tensor_parallel_size=1

🤖 Generated with Claude Code

gemini-code-assist[bot]

This comment was marked as resolved.

@devin-ai-integration devin-ai-integration Bot left a comment

✅ Devin Review: No Issues Found

Devin Review analyzed this PR and found no potential bugs to report.


@CharlieFRuan CharlieFRuan changed the title Prompt based mini batching v3 [fix][train] Prompt-based mini-batching for step-wise training Apr 17, 2026
@CharlieFRuan CharlieFRuan merged commit c7cb2a5 into main Apr 17, 2026
6 checks passed
@CharlieFRuan CharlieFRuan deleted the prompt-based-mini-batching-v3 branch April 17, 2026 19:43
CharlieFRuan added a commit that referenced this pull request Apr 19, 2026
Rebase PR #1479 onto current main (post-PRs #1507/#1526/#1527/#1529).
The original E2E fix's `pad_batch` change is dropped since #1529's
prompt-based mini-batch boundaries removed the need to pad to
`mini_batch_size * n_samples`.

- merge_stepwise_output() in trainer_utils.py collapses multi-turn
  step-wise GeneratorOutput sequences into single sequences when
  consecutive turns share a common prefix, reducing training cost
  from O(T^2) to O(T).
- trainer.py: call merge before extracting generator fields, update
  uids from merged trajectory_ids, emit generate/num_seq_{before,after}_merge.
- Add generator.merge_stepwise_output config flag.
- run_search.sh: MERGE_STEPWISE env var.
- 16 CPU-only tests covering all 3 merging cases, partial merges, and
  validation asserts.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
CharlieFRuan added a commit that referenced this pull request Apr 20, 2026
…1532)

## Summary

This PR implements prefix-aware merging for step-wise training, guarded
by a flag `cfg.generator.merge_stepwise_output` that defaults to False.
During step-wise training, within a trajectory, when consecutive steps
share the same prefix (i.e. no re-tokenization drift or context
management like thinking token stripping), we collapse into a single
`GeneratorOutput` entry. This can reduce the O(T²) training cost
introduced by step-wise (T being number of turns).

- `merge_stepwise_output()` in `skyrl/train/utils/trainer_utils.py`
implements greedy merging: for consecutive turns in the same trajectory
where `prompt[i] + response[i]` is a prefix of `prompt[i+1]`, merge into
one entry. Response tokens are concatenated with the observation delta
(loss-masked to 0) between turns; per-token fields (`loss_masks`,
`rewards`, `rollout_logprobs`) are aligned accordingly; per-turn fields
(`stop_reason`, `is_last_step`, `trajectory_id`) take the last turn's
value.
- `RayPPOTrainer.postprocess_generator_output` calls
`merge_stepwise_output` when `generator.merge_stepwise_output=true`,
updates `uids` from the merged `trajectory_ids`, and logs
`generate/num_seq_{before,after}_merge`.
- Since `uids` may need to be modified, update the signature of
`postprocess_generator_output` to return both `generator_output` and
`uids`, updating the various call sites accordingly.
- New `generator.merge_stepwise_output` config flag (default false).
- `examples/train/search/run_search.sh` accepts `MERGE_STEPWISE=true`
env var to pass the flag through.
- 16 CPU-only unit tests in `tests/train/test_merge_stepwise_output.py`
cover the three merge cases, partial merges, prefix mismatches,
single-turn passthrough, per-trajectory scalar rewards, and
required-field asserts.
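The greedy prefix check at the core of the merge can be sketched like this. The function name, argument shapes (plain token-id lists), and return convention are assumptions for illustration, not the actual `merge_stepwise_output` code:

```python
def try_merge_turns(prompt_a, response_a, loss_mask_a,
                    prompt_b, response_b, loss_mask_b):
    """Merge two consecutive step-wise turns when turn B's prompt extends
    turn A's full sequence. Hypothetical sketch with token-id lists --
    not the actual skyrl implementation.
    """
    full_a = prompt_a + response_a
    # Merge only if A's entire sequence is a literal prefix of B's prompt
    # (i.e. no re-tokenization drift or context management between turns).
    if prompt_b[: len(full_a)] != full_a:
        return None
    # Tokens between A's sequence and B's prompt are the observation delta;
    # they join the merged response with loss mask 0 (no gradient).
    obs_delta = prompt_b[len(full_a):]
    merged_response = response_a + obs_delta + response_b
    merged_mask = loss_mask_a + [0] * len(obs_delta) + loss_mask_b
    return prompt_a, merged_response, merged_mask
```

When the check fails (a `None` return here), the turns stay separate, which is what makes partial merges within a trajectory possible.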

## Test plan

- [x] `pytest tests/train/test_merge_stepwise_output.py` — 16 passed
- [x] `pytest tests/train/test_trainer_utils.py
tests/train/test_prompt_mini_batch.py` — 58 passed (existing tests
unaffected)
- [x] E2E: Search-R1 step-wise GRPO run on Qwen2.5-3B-Instruct, 8×H100,
`MERGE_STEPWISE=true`.

### Curves

With precisely the same setup as #1529, we run:

```bash
MERGE_STEPWISE=true USE_CONVERSATION_MULTI_TURN=true STEP_WISE=true bash examples/train/search/run_search.sh \
  generator.inference_engine.num_engines=8 \
  generator.inference_engine.tensor_parallel_size=1
```
See PR description for more.

Co-authored-by: Deep Sheth <deepsheth3@users.noreply.github.com>
CharlieFRuan added a commit that referenced this pull request Apr 20, 2026
…1538)

(Commit message identical to the #1532 commit above.)

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Deep Sheth <deepsheth3@users.noreply.github.com>